# src/vol5_k2m_cc/cc_targetcov.py
"""
Adaptive coverage selection helpers for the compact-curvature translator.

All percentages in this file are in *percent units*, e.g. 0.5 means 0.5%.
"""

from __future__ import annotations
import numpy as np
from typing import Tuple


def _pct_to_count(pct: float, total: int) -> int:
    """Convert percent (e.g., 0.5 for 0.5%) to a count in [1..total]."""
    if pct <= 0:
        return 0
    k = int(round((pct / 100.0) * total))
    return max(1, min(total, k))


def topk_mask_by_score(score: np.ndarray, k: int) -> np.ndarray:
    """
    Return a boolean mask selecting the top-k pixels by 'score'.

    The threshold is the k-th largest value, and every pixel with
    ``score >= threshold`` is selected. If ties at the threshold overshoot k,
    the surplus tied pixels with the *lowest* flat (C-order) indices are
    dropped. The result is deterministic and has exactly k True entries
    whenever ``0 < k < score.size``.

    NaN entries (floating dtypes only) are ranked as -inf. They are therefore
    only selected when k exceeds the number of non-NaN pixels.

    Parameters
    ----------
    score : np.ndarray
        Score array of any shape; higher is better.
    k : int
        Number of pixels to select. ``k <= 0`` gives an all-False mask, and
        ``k >= score.size`` gives an all-True mask.

    Returns
    -------
    np.ndarray
        Boolean mask with the same shape as ``score``.
    """
    if k <= 0:
        return np.zeros_like(score, dtype=bool)
    if k >= score.size:
        return np.ones_like(score, dtype=bool)

    # np.partition sorts NaN to the end (as if largest). A NaN at the k-th
    # position would make every `>=` comparison False and yield an empty
    # mask, so rank NaN below everything else instead.
    ranked = score
    if np.issubdtype(score.dtype, np.floating):
        nan_mask = np.isnan(score)
        if nan_mask.any():
            ranked = np.where(nan_mask, -np.inf, score)

    flat = ranked.ravel()
    pivot = flat.size - k
    # k-th largest value == element at index (size - k) in ascending order.
    kth = np.partition(flat, pivot)[pivot]
    mask = ranked >= kth

    # At most k-1 values are strictly greater than kth, so any overshoot comes
    # only from ties at kth; drop the first `extra` tied pixels in flat order.
    extra = int(mask.sum()) - k
    if extra > 0:
        tied = np.flatnonzero(ranked == kth)
        mask_flat = mask.ravel()
        mask_flat[tied[:extra]] = False
        mask = mask_flat.reshape(mask.shape)
    return mask


def mask_with_target_coverage(
    pos_score: np.ndarray,
    target_coverage_pct: float,
    min_coverage_pct: float = 0.0,
) -> Tuple[np.ndarray, dict]:
    """
    Select a mask from a *positive* score map achieving ~target coverage.

    The top ``target_coverage_pct`` percent of all H*W pixels are selected
    by ``pos_score`` (see ``topk_mask_by_score`` for tie-breaking). If that
    covers fewer pixels than ``min_coverage_pct``, the top
    ``min_coverage_pct`` percent is selected instead and ``info["strategy"]``
    becomes ``"bumped_to_min"``. This includes the case of a zero target.

    If 'pos_score' has no strictly positive entries, returns an empty mask;
    the minimum-coverage floor is not applied in that case.

    NOTE(review): ranking uses the full map, not just the positive support.
    When the requested count exceeds the number of positive pixels, zero or
    negative pixels are selected as well. Confirm this is intended.

    Parameters
    ----------
    pos_score : np.ndarray
        2-D score map of shape (H, W); higher is better.
    target_coverage_pct : float
        Desired coverage in percent units (0.5 means 0.5%).
    min_coverage_pct : float, optional
        Coverage floor in percent units; 0 (the default) disables it.

    Returns
    -------
    mask : np.ndarray
        Boolean array of shape (H, W).
    info : dict
        Contains ``"strategy"``, ``"target_coverage_pct"``,
        ``"min_coverage_pct"`` and ``"achieved_coverage_pct"`` (percent of
        the H*W pixels that are selected).
    """
    H, W = pos_score.shape
    total = H * W

    info = {
        "strategy": "target_coverage",
        "target_coverage_pct": float(target_coverage_pct),
        "min_coverage_pct": float(min_coverage_pct),
        "achieved_coverage_pct": 0.0,
    }

    # Bail out when there is no strictly positive support. This also covers
    # an empty map (total == 0), which keeps the division below safe.
    if not np.any(pos_score > 0):
        return np.zeros((H, W), dtype=bool), info

    k_target = _pct_to_count(target_coverage_pct, total)
    k_min = _pct_to_count(min_coverage_pct, total)

    # No early return for k_target == 0: top-k with k=0 is already an empty
    # mask, and a positive min_coverage must still be able to bump it.
    mask = topk_mask_by_score(pos_score, k_target)
    n_selected = int(mask.sum())

    if n_selected < k_min:
        mask = topk_mask_by_score(pos_score, k_min)
        n_selected = int(mask.sum())
        info["strategy"] = "bumped_to_min"

    info["achieved_coverage_pct"] = float(100.0 * (n_selected / float(total)))
    return mask, info
